#This script contains code for a function (effect_plotter)
#that takes regression estimates, standard errors, and degrees of
#freedom and creates an effect plot using ggplot2.

#The function is designed in particular for regressions with
#factor variables, such that the variable label and
#all levels (including the reference level) are displayed in the plot.

#This function was written by Kirk Bansak, updated on
#March 5, 2018.
#This function was modified extensively by Ala' Alrababa'h to work with observational data instead of conjoints
#This function was modified by Daniel Masterson on August 7, 2020
#to include 90% CIs in addition to 95% confidence intervals
#Updated again by Ala' to take the degrees of freedom (dof) into account when computing critical values (t quantiles rather than z scores), instead of assuming a large n and using 1.96 or 1.645

library(ggplot2)


# Arguments ---------------------------------------------------------------

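#1. estimate_vec: numeric vector of point estimates, one per non-reference
#level (or per single-level variable), in the same order as names.levels

#2. se_vec: numeric vector of standard errors, aligned with estimate_vec

#3. dof: degrees of freedom passed to qt() for the 90% and 95% critical
#values, indexed like estimate_vec

#4. names.variables: vector of the variable names (strings) to display,
#in same order as model

#5. names.levels: list of vectors of levels (strings), incl. reference
#(listed first), for each variable to display, with levels in same order
#as model

#6. effect.label: string to display describing the x-axis

#7. x.lower: number for x-axis lower bound (NULL uses ggplot default;
#only applied when x.upper is also given)

#8. x.upper: number for x-axis upper bound (NULL uses ggplot default;
#only applied when x.lower is also given)

#9. labs: If FALSE, the variable and level labels on the axis are hidden

#10. title: string used as the plot title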

# Function ----------------------------------------------------------------


# Build an effect plot (point estimates with 90% and 95% t-based intervals)
# for a set of variables. Multi-level (factor) variables get a header row
# followed by one row per level, including the reference level.
#
# Arguments:
#   estimate_vec    point estimates, one per non-reference level (or per
#                   single-level variable), in display order.
#   se_vec          standard errors aligned with estimate_vec.
#   dof             degrees of freedom for qt(); either a single value, used
#                   for every estimate, or one value per estimate.
#   names.variables display name for each variable.
#   names.levels    list with, for each variable, its level names with the
#                   reference level first; a length-1 element marks a
#                   single-level ("solo") variable drawn on its own row.
#   effect.label    label for the effect axis.
#   x.lower,x.upper effect-axis limits; applied only when both are non-NULL.
#   labs            if FALSE, the variable/level labels are hidden.
#   title           plot title.
# Returns the ggplot object produced by plotit().
effect_plotter <- function(estimate_vec,
                           se_vec,
                           dof,
                           names.variables,
                           names.levels,
                           effect.label,
                           x.lower = NULL, x.upper = NULL, labs = FALSE,
                           title = "ABC") {
  if (length(names.variables) != length(names.levels)) {
    stop("Number of sets of levels does not match number of variables!")
  }
  if (length(names.variables) == 0L) {
    stop("No variables to plot!")
  }

  # One block of rows per variable:
  #   multi-level: header (variable name), reference level, other levels,
  #                blank spacer row;
  #   solo:        one row carrying the estimate, blank spacer row.
  # `code` is 1 on rows that receive an estimate (renumbered below).
  blocks <- lapply(seq_along(names.variables), function(i) {
    nv <- names.variables[i]
    nl <- names.levels[[i]]
    n_lvl <- length(nl)
    if (n_lvl > 1) {
      data.frame(
        name = c(nv, nl, NA),
        group = c("empty", rep(paste0("group", i), n_lvl), "empty"),
        code = c(0, 0, rep(1, n_lvl - 1), 0),
        pe = c(NA, rep(0, n_lvl), NA),
        se = c(NA, rep(0, n_lvl), NA),
        solo = 0,
        stringsAsFactors = FALSE
      )
    } else {
      data.frame(
        name = c(nv, NA),
        group = c("1", "empty"),
        code = c(1, 0),
        pe = c(0, NA),
        se = c(0, NA),
        solo = c(1, 0),
        stringsAsFactors = FALSE
      )
    }
  })
  pdat <- do.call(rbind, blocks)

  # Number the rows that receive an estimate 1..k in display order; code 0
  # marks headers, reference levels and spacer rows.
  est_rows <- pdat$code == 1
  pdat$code[est_rows] <- seq_len(sum(est_rows))
  pdat$order <- rev(seq_len(nrow(pdat)))

  # Drop rows labelled with the empty string; NA (spacer) labels are kept.
  # Base subsetting avoids relying on dplyr, which this file does not load.
  pdat <- pdat[is.na(pdat$name) | pdat$name != "", , drop = FALSE]
  rownames(pdat) <- NULL

  # Fill in estimates, standard errors and t critical values by code.
  est_rows <- pdat$code > 0
  est_idx <- pdat$code[est_rows]
  n_est <- max(c(0, est_idx))
  if (length(estimate_vec) < n_est || length(se_vec) < n_est) {
    warning("Fewer estimates or standard errors than rows to fill; ",
            "the missing ones are left as NA.", call. = FALSE)
  }
  # A single dof applies to every estimate; indexing the length-1 result of
  # qt() by estimate number would otherwise make every interval after the
  # first NA.
  if (length(dof) == 1L) {
    dof <- rep(dof, n_est)
  }
  pdat$pe[est_rows] <- estimate_vec[est_idx]
  pdat$se[est_rows] <- se_vec[est_idx]
  # Rows without an estimate (headers, reference levels, spacers) keep NA
  # critical values, so their interval bounds are NA.
  # NOTE(review): this presumably makes ggplot2 drop the reference-level
  # points (pe = 0) as missing values; confirm that is intended.
  pdat$z95 <- NA_real_
  pdat$z90 <- NA_real_
  pdat$z95[est_rows] <- qt(0.975, dof)[est_idx]
  pdat$z90[est_rows] <- qt(0.95, dof)[est_idx]

  pdat$name[is.na(pdat$name)] <- ""

  # plotit() applies the axis limits only when both bounds are non-NULL.
  plotit(d = pdat, effect.label = effect.label,
         x.lower = x.lower, x.upper = x.upper, labs = labs, title = title)
}



# Draw the effect plot for a frame built by effect_plotter(): one row per
# axis position with columns name, group, solo, order, pe, se, z90, z95.
# Points are pe; the larger range is the 90% interval and the smaller,
# semi-transparent range is the 95% interval (pe +/- z * se).
#
# Arguments:
#   d               plotting frame (see above).
#   effect.label    label for the effect axis.
#   x.lower,x.upper effect-axis limits; applied only when both are non-NULL.
#   labs            if FALSE, the variable/level axis labels are hidden.
#   title           plot title.
#   facet           optional facetting spec passed to facet_wrap().
# Returns a ggplot object.
plotit <- function(d, effect.label, x.lower = NULL, x.upper = NULL, labs,
                   title, facet = NULL) {
  base_size <- 8

  # Interval bounds from the per-row critical values.
  d$upper90 <- d$pe + d$z90 * d$se
  d$lower90 <- d$pe - d$z90 * d$se
  d$upper95 <- d$pe + d$z95 * d$se
  d$lower95 <- d$pe - d$z95 * d$se

  # Level rows are indented under their variable header; headers (pe NA) and
  # single-level variables are drawn in bold.
  plot.labels <- as.character(d$name)
  indent <- d$group != "empty" & !is.na(plot.labels) & d$solo == 0
  plot.labels[indent] <- paste0("    ", plot.labels[indent])
  plot.face <- ifelse(is.na(d$pe) | d$solo == 1, "bold", "plain")

  # All points share one shape; the reference/normal distinction is disabled.
  d$ref <- "normal"
  d$order <- as.factor(d$order)

  # Axis text for the variable/level labels, built directly rather than via
  # eval(parse()). When the labels are hidden the title sits further left.
  hide_labels <- labs == FALSE
  axis_text_y <- if (hide_labels) {
    element_blank()
  } else {
    element_text(size = base_size, hjust = 0, vjust = .5,
                 face = rev(plot.face))
  }
  title_hjust <- if (hide_labels) 0.4 else 0.55

  p <- ggplot(d, aes(y = pe, x = order, color = group, size = 0.3,
                     shape = ref)) +
    scale_shape_manual(values = c(16, 1))

  # coord_flip() puts the effect on the horizontal axis, so ylim limits the
  # effect axis.
  if (!is.null(x.lower) && !is.null(x.upper)) {
    p <- p + coord_flip(ylim = c(x.lower, x.upper))
  } else {
    p <- p + coord_flip()
  }

  if (!is.null(facet)) {
    p <- p + facet_wrap(facet)
  }

  # NOTE: ggplot2 >= 3.4 prefers `linewidth` over `size` for lines; `size`
  # is kept for compatibility with older ggplot2 versions.
  p <- p +
    ylab(effect.label) +
    geom_hline(yintercept = 0, size = .5, colour = "darkgrey", linetype = 1) +
    geom_pointrange(aes(ymin = lower90, ymax = upper90),
                    position = "dodge", size = 0.65) +
    geom_pointrange(aes(ymin = lower95, ymax = upper95, alpha = I(0.65)),
                    position = "dodge", size = 0.5) +
    scale_colour_discrete("Attribute:") +
    scale_x_discrete(name = "", labels = rev(plot.labels))

  # theme_economist() is not defined in this file (presumably
  # ggthemes::theme_economist); the caller must have it attached.
  p <- p +
    theme_economist() +
    theme(
      axis.text.x = element_text(size = base_size, hjust = .5, vjust = 1,
                                 colour = "black"),
      axis.text.y = axis_text_y,
      axis.ticks.x = element_blank(),
      axis.title.y = element_text(size = base_size + 1, angle = 90,
                                  vjust = .01, hjust = .1),
      axis.title.x = element_text(size = base_size + 1),
      legend.position = "none",
      plot.title = element_text(size = base_size + 2, face = "bold",
                                hjust = title_hjust),
      panel.grid.major = element_line(size = 0.28),
      panel.grid.minor = element_blank()
    ) +
    # `labs` is a logical argument here, so name ggplot2's labs() explicitly.
    ggplot2::labs(title = title)

  p
}
